From 487d7fc1b424d4d7c0a33be5c02cc6dfe1e689fd Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Fri, 5 Aug 2022 16:34:44 +0900 Subject: [PATCH] support slice in pieces/nbests objects Signed-off-by: Kentaro Hayashi Gbp-Pq: Name 0019-support-slice-in-pieces-nbests-objects.patch --- python/src/sentencepiece/__init__.py | 8 ++++++++ python/src/sentencepiece/sentencepiece.i | 8 ++++++++ python/test/sentencepiece_test.py | 4 ++++ 3 files changed, 20 insertions(+) diff --git a/python/src/sentencepiece/__init__.py b/python/src/sentencepiece/__init__.py index ce9d60d..cf06830 100644 --- a/python/src/sentencepiece/__init__.py +++ b/python/src/sentencepiece/__init__.py @@ -145,6 +145,10 @@ class ImmutableSentencePieceText(object): return self.len def __getitem__(self, index): + if isinstance(index, slice): + return [self.proto._pieces(i) for i in range(self.len)][index.start:index.stop:index.step] + if index < 0: + index = index + self.len if index < 0 or index >= self.len: raise IndexError('piece index is out of range') return self.proto._pieces(index) @@ -202,6 +206,10 @@ class ImmutableNBestSentencePieceText(object): return self.len def __getitem__(self, index): + if isinstance(index, slice): + return [self.proto._nbests(i) for i in range(self.len)][index.start:index.stop:index.step] + if index < 0: + index = index + self.len if index < 0 or index >= self.len: raise IndexError('nbests index is out of range') return self.proto._nbests(index) diff --git a/python/src/sentencepiece/sentencepiece.i b/python/src/sentencepiece/sentencepiece.i index e22f763..2ac68a8 100644 --- a/python/src/sentencepiece/sentencepiece.i +++ b/python/src/sentencepiece/sentencepiece.i @@ -1293,6 +1293,10 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { return self.len def __getitem__(self, index): + if isinstance(index, slice): + return [self.proto._pieces(i) for i in range(self.len)][index.start:index.stop:index.step] + if index < 0: + index = index + self.len if index < 0 or index >= self.len: raise IndexError('piece index is out of range') return self.proto._pieces(index) @@ -1336,6 +1340,10 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { return self.len def __getitem__(self, index): + if isinstance(index, slice): + return [self.proto._nbests(i) for i in range(self.len)][index.start:index.stop:index.step] + if index < 0: + index = index + self.len if index < 0 or index >= self.len: raise IndexError('nbests index is out of range') return self.proto._nbests(index) diff --git a/python/test/sentencepiece_test.py b/python/test/sentencepiece_test.py index 6cbe077..92327ac 100755 --- a/python/test/sentencepiece_test.py +++ b/python/test/sentencepiece_test.py @@ -395,6 +395,10 @@ class TestSentencepieceProcessor(unittest.TestCase): self.assertEqual( self.sp_.Decode([x.id for x in s3.nbests[i].pieces]), text) + # slice + self.assertEqual(s1.pieces[::-1], list(reversed(s1.pieces))) + self.assertEqual(s3.nbests[::-1], list(reversed(s3.nbests))) + # Japanese offset s1 = self.jasp_.EncodeAsImmutableProto('吾輩は猫である。Hello world. ABC 123') surfaces1 = [s1.text[x.begin:x.end] for x in s1.pieces] -- 2.30.2